# Owner(s): ["module: pytree"]

import collections
import inspect
import re
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from typing import Any, NamedTuple

import torch
import torch.utils._pytree as py_pytree
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_FBCODE,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)

if IS_FBCODE:
    # optree is not yet enabled in fbcode, so just re-test the python implementation
    cxx_pytree = py_pytree
else:
    import torch.utils._cxx_pytree as cxx_pytree

GlobalPoint = namedtuple("GlobalPoint", ["x", "y"])


class GlobalDummyType:
    def __init__(self, x, y):
        self.x = x
        self.y = y


class TestGenericPytree(TestCase):
    """Tests run against both the Python and the C++ pytree implementations."""

    def test_aligned_public_apis(self):
        """Check that the Python and C++ pytree modules expose matching public APIs.

        Both modules must export the same ``__all__``. Every exported name must
        be a class in both modules or a function in both, and exported functions
        must agree on parameter names, which parameters are positional, each
        parameter's kind and default, and each parameter's annotation (a
        qualified ``TreeSpec`` annotation is compared without its qualifier).
        """
        public_apis = py_pytree.__all__

        self.assertEqual(public_apis, cxx_pytree.__all__)

        positional_kinds = {
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        }

        def positional_param_names(signature):
            # Names of the parameters that may be passed positionally, in order.
            return [
                n for n, p in signature.parameters.items() if p.kind in positional_kinds
            ]

        def normalize_treespec(annotation):
            # Strip any dotted qualifier in front of ``TreeSpec`` so the two
            # implementations' TreeSpec annotations can be compared textually.
            return re.sub(
                r"(?:\b)([\w\.]*)TreeSpec(?:\b)", "TreeSpec", str(annotation)
            )

        for name in public_apis:
            cxx_api = getattr(cxx_pytree, name)
            py_api = getattr(py_pytree, name)

            self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api))
            self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api))
            if not inspect.isfunction(cxx_api):
                continue

            cxx_signature = inspect.signature(cxx_api)
            py_signature = inspect.signature(py_api)

            # Check the parameter names are the same.
            self.assertEqual(
                list(cxx_signature.parameters), list(py_signature.parameters)
            )

            # Check the positional parameters are the same.
            self.assertEqual(
                positional_param_names(cxx_signature),
                positional_param_names(py_signature),
            )

            for py_name, py_param in py_signature.parameters.items():
                self.assertIn(py_name, cxx_signature.parameters)
                cxx_param = cxx_signature.parameters[py_name]

                # Check parameter kinds and default values are the same.
                self.assertEqual(cxx_param.kind, py_param.kind)
                self.assertEqual(cxx_param.default, py_param.default)

                # Check parameter annotations are the same.
                msg = (
                    f"C++ parameter {cxx_param} "
                    f"does not match Python parameter {py_param} "
                    f"for API `{name}`"
                )
                if "TreeSpec" in str(cxx_param.annotation):
                    self.assertIn("TreeSpec", str(py_param.annotation))
                    self.assertEqual(
                        normalize_treespec(cxx_param.annotation),
                        normalize_treespec(py_param.annotation),
                        msg=msg,
                    )
                else:
                    self.assertEqual(
                        cxx_param.annotation, py_param.annotation, msg=msg
                    )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_register_pytree_node(self, pytree_impl):
        """Custom types are leaves until registered; double registration fails.

        After registration, a ``UserDict`` subclass flattens to its values. A
        second registration of the same type raises ``ValueError``.
        """

        class MyDict(UserDict):
            pass

        d = MyDict(a=1, b=2, c=3)

        # Custom types are leaf nodes by default
        values, spec = pytree_impl.tree_flatten(d)
        self.assertEqual(values, [d])
        self.assertIs(values[0], d)
        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
        self.assertTrue(spec.is_leaf())

        # Register MyDict as a pytree node
        pytree_impl.register_pytree_node(
            MyDict,
            lambda d: (list(d.values()), list(d.keys())),
            lambda values, keys: MyDict(zip(keys, values)),
        )

        values, spec = pytree_impl.tree_flatten(d)
        self.assertEqual(values, [1, 2, 3])
        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))

        # Do not allow registering the same type twice
        with self.assertRaisesRegex(ValueError, "already registered"):
            pytree_impl.register_pytree_node(
                MyDict,
                lambda d: (list(d.values()), list(d.keys())),
                lambda values, keys: MyDict(zip(keys, values)),
            )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_leaf(self, pytree_impl):
        """Scalars, None, types and tensors flatten to a single leaf and round-trip."""

        def run_test_with_leaf(leaf):
            values, treespec = pytree_impl.tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, pytree_impl.LeafSpec())

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        run_test_with_leaf(1)
        run_test_with_leaf(1.0)
        run_test_with_leaf(None)
        run_test_with_leaf(bool)
        run_test_with_leaf(torch.randn(3, 3))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda tup: py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec() for _ in tup]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
        """A tuple flattens to its elements and a spec matching ``gen_expected_fn``.

        Unflattening must rebuild an equal ``tuple``.
        """

        def run_test(tup):
            expected_spec = gen_expected_fn(tup)
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda lst: py_pytree.TreeSpec(
                        list, None, [py_pytree.LeafSpec() for _ in lst]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
        """A list flattens to its elements and a spec matching ``gen_expected_fn``.

        Unflattening must rebuild an equal ``list``.
        """

        def run_test(lst):
            expected_spec = gen_expected_fn(lst)
            values, treespec = pytree_impl.tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda dct: py_pytree.TreeSpec(
                        dict,
                        list(dct.keys()),
                        [py_pytree.LeafSpec() for _ in dct.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
        """A dict flattens to its values and a spec matching ``gen_expected_fn``.

        Unflattening must rebuild an equal ``dict``.
        """

        def run_test(dct):
            expected_spec = gen_expected_fn(dct)
            values, treespec = pytree_impl.tree_flatten(dct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(dct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, dct)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda odict: py_pytree.TreeSpec(
                        OrderedDict,
                        list(odict.keys()),
                        [py_pytree.LeafSpec() for _ in odict.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda odict: cxx_pytree.tree_structure(
                        OrderedDict.fromkeys(odict, 0)
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
        """An OrderedDict flattens in insertion order and round-trips as OrderedDict."""

        def run_test(odict):
            expected_spec = gen_expected_fn(odict)
            values, treespec = pytree_impl.tree_flatten(odict)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(odict.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, odict)
            self.assertIsInstance(unflattened, OrderedDict)

        od = OrderedDict()
        run_test(od)

        od["b"] = 1
        od["a"] = torch.tensor(3.14)
        run_test(od)

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda ddct: py_pytree.TreeSpec(
                        defaultdict,
                        [ddct.default_factory, list(ddct.keys())],
                        [py_pytree.LeafSpec() for _ in ddct.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda ddct: cxx_pytree.tree_structure(
                        defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
        """A defaultdict round-trips with its values and its ``default_factory``."""

        def run_test(ddct):
            expected_spec = gen_expected_fn(ddct)
            values, treespec = pytree_impl.tree_flatten(ddct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(ddct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, ddct)
            self.assertEqual(unflattened.default_factory, ddct.default_factory)
            self.assertIsInstance(unflattened, defaultdict)

        run_test(defaultdict(list, {}))
        run_test(defaultdict(int, {"a": 1}))
        run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
        run_test(defaultdict(int, {1: torch.randn(2, 3)}))
        run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda deq: py_pytree.TreeSpec(
                        deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda deq: cxx_pytree.tree_structure(
                        deque(deq, maxlen=deq.maxlen)
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn):
        """A deque round-trips with its elements and its ``maxlen``."""

        def run_test(deq):
            expected_spec = gen_expected_fn(deq)
            values, treespec = pytree_impl.tree_flatten(deq)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(deq))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, deq)
            self.assertEqual(unflattened.maxlen, deq.maxlen)
            self.assertIsInstance(unflattened, deque)

        run_test(deque([]))
        run_test(deque([1.0, 2]))
        run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_namedtuple(self, pytree_impl):
        """A namedtuple flattens to its fields and round-trips as the same class."""
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            if pytree_impl is py_pytree:
                expected_spec = py_pytree.TreeSpec(
                    namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
                )
            else:
                expected_spec = cxx_pytree.tree_structure(Point(0, 1))
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, Point)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    @parametrize(
        "op",
        [
            subtest(torch.max, name="max"),
            subtest(torch.min, name="min"),
        ],
    )
    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_return_types(self, pytree_impl, op):
        """The return value of ``torch.max``/``torch.min`` with ``dim`` round-trips.

        Its leaves must be tensors, and unflattening must give back an equal
        object of the same type.
        """
        x = torch.randn(3, 3)
        expected = op(x, dim=0)

        values, spec = pytree_impl.tree_flatten(expected)
        # Check that values is actually List[Tensor] and not (ReturnType(...),)
        for value in values:
            self.assertIsInstance(value, torch.Tensor)
        result = pytree_impl.tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_nested(self, pytree_impl):
        """Nested dict/list/tuple trees round-trip through flatten/unflatten."""

        def run_test(pytree):
            values, treespec = pytree_impl.tree_flatten(pytree)
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: python basic data structures (dict list tuple) all have
            # contents equality defined on them, so the following works for them.
            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_with_is_leaf(self, pytree_impl):
        """An ``is_leaf`` that accepts everything but the root flattens one level.

        The direct children of the root become the leaves, so the spec has
        exactly one node more than it has leaves.
        """

        def run_test(pytree, one_level_leaves):
            values, treespec = pytree_impl.tree_flatten(
                pytree, is_leaf=lambda x: x is not pytree
            )
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_nodes - 1)
            self.assertEqual(len(values), treespec.num_leaves)
            self.assertEqual(len(values), treespec.num_children)
            self.assertEqual(values, one_level_leaves)

            self.assertEqual(
                treespec,
                pytree_impl.tree_structure(
                    pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec)
                ),
            )

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            ([()], [()]),
            (([],), [[]]),
            ({"a": ()}, [()]),
            ({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]),
            (
                {
                    "a": 0,
                    "b": [1, {"c": 2}, torch.ones(3)],
                    "c": (torch.zeros(2, 3), 1),
                },
                [0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)],
            ),
        ]
        for case in cases:
            run_test(*case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_map(self, pytree_impl):
        """``tree_map`` applies a function to every leaf and keeps the structure."""

        def run_test(pytree):
            def f(x):
                return x * 3

            sm1 = sum(map(f, pytree_impl.tree_leaves(pytree)))
            sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree)))
            self.assertEqual(sm1, sm2)

            def invf(x):
                return x // 3

            self.assertEqual(
                pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)),
                pytree,
            )

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_map_multi_inputs(self, pytree_impl):
        """``tree_map`` over several trees passes corresponding subtrees together.

        The extra trees may hold deeper structure where the first tree has a
        leaf; that whole subtree is passed to the function as one argument.
        """

        def run_test(pytree):
            def f(x, y, z):
                return x, [y, (z, 0)]

            pytree_x = pytree
            pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree)
            pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree)

            self.assertEqual(
                pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z),
                pytree_impl.tree_map(
                    lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree
                ),
            )

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_map_only(self, pytree_impl):
        """``tree_map_only`` with a type only transforms leaves of that type."""
        self.assertEqual(
            pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
        )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_map_only_predicate_fn(self, pytree_impl):
        """``tree_map_only`` also accepts a predicate that selects leaves."""
        self.assertEqual(
            pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
        )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_all_any(self, pytree_impl):
        """Check ``tree_all``/``tree_any`` and their type-filtered ``*_only`` variants."""
        self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3]))
        self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1]))
        self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1]))
        self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2]))
        self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
        self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
        self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
        self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_broadcast_to_and_flatten(self, pytree_impl):
        """Broadcast a tree onto the spec of another tree, then flatten it.

        Each case is ``(pytree, to_pytree, expected)``. ``expected`` is the
        flattened broadcast result, or ``None`` when the structures cannot be
        broadcast together.
        """
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = pytree_impl.tree_flatten(to_pytree)
            result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_pytree_serialize_bad_input(self, pytree_impl):
        """``treespec_dumps`` rejects a non-TreeSpec argument with ``TypeError``."""
        with self.assertRaises(TypeError):
            pytree_impl.treespec_dumps("random_blurb")


class TestPythonPytree(TestCase):
    def test_deprecated_register_pytree_node(self):
        """The private ``_register_pytree_node`` alias still registers, but warns.

        The first call must emit a ``UserWarning`` naming the deprecated entry
        point. Registering the same type again must emit a ``UserWarning``
        saying it is already registered.
        """

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        def flatten_dummy(dummy):
            return [dummy.x, dummy.y], None

        def unflatten_dummy(xs, _):
            return DummyType(*xs)

        # First registration, then a duplicate registration of the same type.
        for expected_warning in (
            "torch.utils._pytree._register_pytree_node",
            "already registered",
        ):
            with self.assertWarnsRegex(UserWarning, expected_warning):
                py_pytree._register_pytree_node(
                    DummyType, flatten_dummy, unflatten_dummy
                )

    def test_treespec_equality(self):
        """TreeSpecs compare by structure: same node type and children are equal."""
        make_leaf = py_pytree.LeafSpec
        make_tree = py_pytree.TreeSpec

        # Two independently built but structurally identical specs are equal.
        for build in (
            make_leaf,
            lambda: make_tree(list, None, []),
            lambda: make_tree(list, None, [make_leaf()]),
        ):
            self.assertEqual(build(), build())

        # Specs that differ only in node type must be unequal under both the
        # ``==`` and the ``!=`` operators.
        tuple_spec = make_tree(tuple, None, [])
        list_spec = make_tree(list, None, [])
        self.assertFalse(tuple_spec == list_spec)
        self.assertTrue(tuple_spec != list_spec)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        """The repr of a nested spec is a readable, indented multi-line string.

        This is the eager variant; under Dynamo, ``test_treespec_repr_dynamo``
        runs instead.
        """
        _, spec = py_pytree.tree_flatten((0, [0, 0, [0]]))

        # Leaves print as ``*``; each nested TreeSpec starts on a new line,
        # indented two more spaces than its parent's children.
        expected_repr = "\n".join(
            [
                "TreeSpec(tuple, None, [*,",
                "  TreeSpec(list, None, [*,",
                "    *,",
                "    TreeSpec(list, None, [*])])])",
            ]
        )
        self.assertEqual(repr(spec), expected_repr)

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        """Dynamo variant of ``test_treespec_repr``: check a nested spec's repr.

        Runs only when TEST_WITH_TORCHDYNAMO is set (the eager run uses
        ``test_treespec_repr`` instead). Compares against the inline expected
        string with ``assertExpectedInline``.
        """
        # Check that the multi-line repr of (0, [0, 0, [0]]) looks sane.
        # NOTE: keep the expected value as a single triple-quoted literal so
        # assertExpectedInline can locate and rewrite it in accept mode.
        pytree = (0, [0, 0, [0]])
        _, spec = py_pytree.tree_flatten(pytree)
        self.assertExpectedInline(
            repr(spec),
            """\
TreeSpec(tuple, None, [*,
  TreeSpec(list, None, [*,
    *,
    TreeSpec(list, None, [*])])])""",
        )

    @parametrize(
        "spec",
        [
            # py_pytree.tree_structure([])
            py_pytree.TreeSpec(list, None, []),
            # py_pytree.tree_structure(())
            py_pytree.TreeSpec(tuple, None, []),
            # py_pytree.tree_structure({})
            py_pytree.TreeSpec(dict, [], []),
            # py_pytree.tree_structure([0])
            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
            # py_pytree.tree_structure([0, 1])
            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec() for _ in range(2)]),
            # py_pytree.tree_structure((0, 1, 2))
            py_pytree.TreeSpec(tuple, None, [py_pytree.LeafSpec() for _ in range(3)]),
            # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
            py_pytree.TreeSpec(
                dict, ["a", "b", "c"], [py_pytree.LeafSpec() for _ in range(3)]
            ),
            # py_pytree.tree_structure(
            #     OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
            # )
            py_pytree.TreeSpec(
                OrderedDict,
                ["a", "b", "c"],
                [
                    py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec() for _ in range(2)]
                    ),
                    py_pytree.LeafSpec(),
                    py_pytree.TreeSpec(
                        dict, ["a", "b", "c"], [py_pytree.LeafSpec() for _ in range(3)]
                    ),
                ],
            ),
            # py_pytree.tree_structure([(0, 1, [2, 3])])
            py_pytree.TreeSpec(
                list,
                None,
                [
                    py_pytree.TreeSpec(
                        tuple,
                        None,
                        [
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                            py_pytree.TreeSpec(
                                list, None, [py_pytree.LeafSpec() for _ in range(2)]
                            ),
                        ],
                    ),
                ],
            ),
            # py_pytree.tree_structure(
            #     defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
            # )
            py_pytree.TreeSpec(
                defaultdict,
                [list, ["a", "b", "c"]],
                [
                    py_pytree.TreeSpec(
                        list, None, [py_pytree.LeafSpec() for _ in range(2)]
                    ),
                    py_pytree.TreeSpec(
                        list, None, [py_pytree.LeafSpec() for _ in range(2)]
                    ),
                    py_pytree.TreeSpec(dict, [], []),
                ],
            ),
        ],
    )
    def test_pytree_serialize(self, spec):
        """A hand-built TreeSpec survives ``treespec_dumps``/``treespec_loads``.

        The spec is checked first: unflattening placeholder leaves with it and
        re-deriving the structure must give back an equal spec. Then it is
        dumped to a string and loaded back unchanged.
        """
        placeholder_tree = py_pytree.tree_unflatten([0] * spec.num_leaves, spec)
        self.assertEqual(spec, py_pytree.tree_structure(placeholder_tree))

        dumped = py_pytree.treespec_dumps(spec)
        self.assertIsInstance(dumped, str)
        loaded = py_pytree.treespec_loads(dumped)
        self.assertEqual(spec, loaded)

    def test_pytree_serialize_namedtuple(self):
        """A namedtuple spec round-trips through dumps/loads, keeping its field names."""
        Point = namedtuple("Point", ["x", "y"])
        child_specs = [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        spec = py_pytree.TreeSpec(namedtuple, Point, child_specs)

        serialized = py_pytree.treespec_dumps(spec)
        roundtrip_spec = py_pytree.treespec_loads(serialized)

        # Loading recreates the namedtuple class, so the round-tripped context
        # is a different type object; only its field names can be compared.
        self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)

    def test_pytree_custom_type_serialize_bad(self):
        """Serializing a spec whose node type was registered without a
        ``serialized_type_name`` raises ``NotImplementedError``."""

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        # Deliberately registered without any serialization metadata.
        py_pytree.register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
        )

        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )
        with self.assertRaisesRegex(
            NotImplementedError, "No registered serialization name"
        ):
            # Expected to raise; there is no result to keep.
            py_pytree.treespec_dumps(spec)

    def test_pytree_custom_type_serialize(self):
        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        def flatten(dummy):
            return [dummy.x, dummy.y], None

        def unflatten(xs, _):
            return DummyType(*xs)

        # Register with a serialization name and a pair of context converters;
        # the dumpable context is always the marker string "moo".
        py_pytree.register_pytree_node(
            DummyType,
            flatten,
            unflatten,
            serialized_type_name="test_pytree_custom_type_serialize.DummyType",
            to_dumpable_context=lambda context: "moo",
            from_dumpable_context=lambda dumpable_context: None,
        )
        children = [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        spec = py_pytree.TreeSpec(DummyType, None, children)

        # Explicitly request serialization protocol 1.
        dumped = py_pytree.treespec_dumps(spec, 1)
        self.assertIn("moo", dumped)
        self.assertEqual(py_pytree.treespec_loads(dumped), spec)

    def test_pytree_serialize_register_bad(self):
        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        # Supplying only one of the two context converters is rejected at
        # registration time.
        expected_msg = "Both to_dumpable_context and from_dumpable_context"
        with self.assertRaisesRegex(ValueError, expected_msg):
            py_pytree.register_pytree_node(
                DummyType,
                lambda dummy: ([dummy.x, dummy.y], None),
                lambda xs, _: DummyType(*xs),
                serialized_type_name="test_pytree_serialize_register_bad.DummyType",
                to_dumpable_context=lambda context: "moo",
            )

    def test_pytree_context_serialize_bad(self):
        """A ``to_dumpable_context`` that returns a value the JSON encoder
        cannot handle makes ``treespec_dumps`` raise ``TypeError``."""

        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        py_pytree.register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
            # Named "<test name>.DummyType", like the sibling serialization tests.
            serialized_type_name="test_pytree_context_serialize_bad.DummyType",
            # Returns a class object, which is not JSON serializable.
            to_dumpable_context=lambda context: DummyType,
            from_dumpable_context=lambda dumpable_context: None,
        )

        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        with self.assertRaisesRegex(
            TypeError, "Object of type type is not JSON serializable"
        ):
            py_pytree.treespec_dumps(spec)

    def test_pytree_serialize_bad_protocol(self):
        """Unknown protocol numbers are rejected by both dumps and loads."""
        import json

        Point = namedtuple("Point", ["x", "y"])
        spec = py_pytree.TreeSpec(
            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
            py_pytree.treespec_dumps(spec, -1)

        # The serialized form is a JSON ``[protocol, data]`` pair. Keep the
        # payload but replace the protocol number with an invalid one.
        serialized_spec = py_pytree.treespec_dumps(spec)
        _, data = json.loads(serialized_spec)
        bad_protocol_serialized_spec = json.dumps((-1, data))

        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
            py_pytree.treespec_loads(bad_protocol_serialized_spec)

    def test_saved_serialized(self):
        # Hand-built equivalent of
        # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})]))
        leaf = py_pytree.LeafSpec
        tuple_child = py_pytree.TreeSpec(tuple, None, [leaf(), leaf()])
        leaf_child = leaf()
        dict_child = py_pytree.TreeSpec(dict, [4, 5, 6], [leaf(), leaf(), leaf()])
        complicated_spec = py_pytree.TreeSpec(
            OrderedDict, [1, 2, 3], [tuple_child, leaf_child, dict_child]
        )

        # Sanity check: the hand-built spec matches the structure of a tree
        # rebuilt from it.
        rebuilt = py_pytree.tree_unflatten(
            [0] * complicated_spec.num_leaves, complicated_spec
        )
        self.assertEqual(complicated_spec, py_pytree.tree_structure(rebuilt))

        # Golden protocol-1 serialization. This presumably guards compatibility
        # with previously saved specs, so the expected text must not drift.
        saved_spec = (
            '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
            '"children_spec": [{"type": "builtins.tuple", "context": "null", '
            '"children_spec": [{"type": null, "context": null, '
            '"children_spec": []}, {"type": null, "context": null, '
            '"children_spec": []}]}, {"type": null, "context": null, '
            '"children_spec": []}, {"type": "builtins.dict", "context": '
            '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
            '"children_spec": []}, {"type": null, "context": null, "children_spec": '
            '[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
        )
        self.assertEqual(py_pytree.treespec_dumps(complicated_spec), saved_spec)
        self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))

    def test_tree_map_with_path(self):
        # A one-element list holding {i: i}. Each value's key path is
        # (list index 0, dict key i), so value - key + index is always 0.
        tree = [dict(zip(range(10), range(10)))]

        def zero_out(key_path, value):
            list_entry, dict_entry = key_path[0], key_path[1]
            return value - dict_entry.key + list_entry.idx

        all_zeros = py_pytree.tree_map_with_path(zero_out, tree)
        self.assertEqual(all_zeros, [{i: 0 for i in range(10)}])

    def test_tree_map_with_path_multiple_trees(self):
        @dataclass
        class ACustomPytree:
            x: Any
            y: Any
            z: Any

        first = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
        # Same structure as `first`, with every mapped leaf equal to 2.
        second = [
            ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"),
            2,
        ]

        # x and y are children; z travels as context and is never mapped.
        py_pytree.register_pytree_node(
            ACustomPytree,
            flatten_fn=lambda f: ([f.x, f.y], f.z),
            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
        )

        def add_leaves(_key_path, a, b):
            return a + b

        summed = py_pytree.tree_map_with_path(add_leaves, first, second)
        # Adding `second` leaf-wise is the same as adding 2 to each leaf.
        expected = py_pytree.tree_map(lambda a: a + 2, first)
        self.assertEqual(summed, expected)

    @skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
    def test_tree_flatten_with_path_is_leaf(self):
        """An ``is_leaf`` predicate stops flattening at matching subtrees,
        which are then returned (by identity) as leaves."""
        leaf_dict = {"foo": [3]}
        pytree = (["hello", [1, 2], leaf_dict],)
        key_leaves, _ = py_pytree.tree_flatten_with_path(
            pytree, is_leaf=lambda x: isinstance(x, dict)
        )
        # The dict is one leaf; its contents must not be flattened further.
        self.assertEqual(
            [leaf for _, leaf in key_leaves], ["hello", 1, 2, leaf_dict]
        )
        # assertIs reports both objects on failure, unlike assertTrue(a is b).
        self.assertIs(key_leaves[-1][1], leaf_dict)

    def test_tree_flatten_with_path_roundtrip(self):
        class ANamedTuple(NamedTuple):
            x: torch.Tensor
            y: int
            z: str

        @dataclass
        class ACustomPytree:
            x: Any
            y: Any
            z: Any

        def flatten(node):
            return [node.x, node.y], node.z

        def unflatten(children, context):
            return ACustomPytree(children[0], children[1], context)

        def flatten_with_keys(node):
            return (("x", node.x), ("y", node.y)), node.z

        py_pytree.register_pytree_node(
            ACustomPytree,
            flatten_fn=flatten,
            unflatten_fn=unflatten,
            flatten_with_keys_fn=flatten_with_keys,
        )

        cases = (
            (None,),
            ["hello", [1, 2], {"foo": [3]}],
            [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
            [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
        )
        # Flattening with paths and unflattening the bare leaves must rebuild
        # an equal tree.
        for original in cases:
            key_leaves, spec = py_pytree.tree_flatten_with_path(original)
            leaves = [leaf for _key_path, leaf in key_leaves]
            self.assertEqual(py_pytree.tree_unflatten(leaves, spec), original)

    def test_tree_leaves_with_path(self):
        class ANamedTuple(NamedTuple):
            x: torch.Tensor
            y: int
            z: str

        @dataclass
        class ACustomPytree:
            x: Any
            y: Any
            z: Any

        def flatten(node):
            return [node.x, node.y], node.z

        def unflatten(children, context):
            return ACustomPytree(children[0], children[1], context)

        def flatten_with_keys(node):
            return (("x", node.x), ("y", node.y)), node.z

        py_pytree.register_pytree_node(
            ACustomPytree,
            flatten_fn=flatten,
            unflatten_fn=unflatten,
            flatten_with_keys_fn=flatten_with_keys,
        )

        cases = (
            (None,),
            ["hello", [1, 2], {"foo": [3]}],
            [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
            [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
        )
        # tree_leaves_with_path must agree with the (key_path, leaf) list
        # returned by tree_flatten_with_path.
        for original in cases:
            with_paths, _ = py_pytree.tree_flatten_with_path(original)
            self.assertEqual(with_paths, py_pytree.tree_leaves_with_path(original))

    def test_key_str(self):
        class ANamedTuple(NamedTuple):
            x: str
            y: int

        tree = (
            ["hello", [1, 2], {"foo": [3], "bar": [ANamedTuple(x="baz", y=10)]}],
        )
        flattened, _ = py_pytree.tree_flatten_with_path(tree)
        # Render each leaf as "<keystr of its path>: <value>".
        rendered = [
            f"{py_pytree.keystr(key_path)}: {value}" for key_path, value in flattened
        ]
        expected = [
            "[0][0]: hello",
            "[0][1][0]: 1",
            "[0][1][1]: 2",
            "[0][2]['foo'][0]: 3",
            "[0][2]['bar'][0].x: baz",
            "[0][2]['bar'][0].y: 10",
        ]
        self.assertEqual(rendered, expected)

    @skipIfTorchDynamo("AssertionError in dynamo")
    def test_flatten_flatten_with_key_consistency(self):
        """For every registered node type with a sample value below, plain and
        path-aware flattening yield the same leaves and the same spec."""
        samples = {
            list: [1, 2, 3],
            tuple: (1, 2, 3),
            dict: {"foo": 1, "bar": 2},
            namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
            OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
            defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
            deque: deque([1, 2, 3]),
            torch.Size: torch.Size([1, 2, 3]),
            immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
            immutable_list: immutable_list([1, 2, 3]),
        }

        # Registered types without a sample above are skipped.
        for node_type in py_pytree.SUPPORTED_NODES:
            if node_type not in samples:
                continue
            sample = samples[node_type]
            with_paths, path_spec = py_pytree.tree_flatten_with_path(sample)
            leaves, plain_spec = py_pytree.tree_flatten(sample)

            self.assertEqual(leaves, [leaf for _key_path, leaf in with_paths])
            self.assertEqual(path_spec, plain_spec)

    def test_key_access(self):
        class ANamedTuple(NamedTuple):
            x: str
            y: int

        tree = (
            ["hello", [1, 2], {"foo": [3], "bar": [ANamedTuple(x="baz", y=10)]}],
        )
        # Following any key path produced by flattening must reach its leaf.
        key_leaves = py_pytree.tree_flatten_with_path(tree)[0]
        for key_path, leaf in key_leaves:
            self.assertEqual(py_pytree.key_get(tree, key_path), leaf)


class TestCxxPytree(TestCase):
    """Tests specific to the C++ pytree implementation (``cxx_pytree``)."""

    def setUp(self):
        # Decide on skipping before running the base-class setUp: unittest does
        # not call tearDown when setUp raises, so any base-class setup done
        # before a skip would never get its matching cleanup.
        if IS_FBCODE:
            raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
        # Chain to the base class so TestCase's shared per-test setup runs.
        super().setUp()

    def test_treespec_equality(self):
        self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertEqual(
            repr(spec),
            "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
        )

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertExpectedInline(
            repr(spec),
            "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
        )

    @parametrize(
        "spec",
        [
            cxx_pytree.tree_structure([]),
            cxx_pytree.tree_structure(()),
            cxx_pytree.tree_structure({}),
            cxx_pytree.tree_structure([0]),
            cxx_pytree.tree_structure([0, 1]),
            cxx_pytree.tree_structure((0, 1, 2)),
            cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
            cxx_pytree.tree_structure(
                OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
            ),
            cxx_pytree.tree_structure([(0, 1, [2, 3])]),
            cxx_pytree.tree_structure(
                defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
            ),
        ],
    )
    def test_pytree_serialize(self, spec):
        # Sanity check: the spec survives an unflatten/re-flatten cycle.
        self.assertEqual(
            spec,
            cxx_pytree.tree_structure(
                cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
            ),
        )

        # Round-trip through the string serialization.
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        self.assertIsInstance(serialized_spec, str)
        self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))

    def test_pytree_serialize_namedtuple(self):
        # Module-level namedtuple type.
        spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))

        roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

        # Namedtuple type defined locally inside the test.
        LocalPoint = namedtuple("LocalPoint", ["x", "y"])
        spec = cxx_pytree.tree_structure(LocalPoint(0, 1))

        roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

    def test_pytree_custom_type_serialize(self):
        # Custom node type defined at module level.
        cxx_pytree.register_pytree_node(
            GlobalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: GlobalDummyType(*xs),
            serialized_type_name="GlobalDummyType",
        )
        spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)

        # Custom node type defined locally inside the test.
        class LocalDummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        cxx_pytree.register_pytree_node(
            LocalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: LocalDummyType(*xs),
            serialized_type_name="LocalDummyType",
        )
        spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)


# Expand the @parametrize-decorated tests in each class (for example
# test_pytree_serialize) into concrete, individually runnable test methods.
instantiate_parametrized_tests(TestGenericPytree)
instantiate_parametrized_tests(TestPythonPytree)
instantiate_parametrized_tests(TestCxxPytree)


if __name__ == "__main__":
    # Hand off to the shared PyTorch test runner from common_utils.
    run_tests()
